Fit the FDA model

Author

Nicholas Harbour

Published

February 10, 2025, 10:31:53

Modified

February 10, 2025, 10:31:23

1 Load in the data

Import python libraries required for the analysis

Code
# Import the necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import datetime as dt
from scipy.special import i0 
from scipy.optimize import curve_fit
from scipy.stats import circvar
from matplotlib.ticker import MaxNLocator

# import functions 
from functions import *

Load in the data and interpolate it.

Code
# Hormone time series; `load_data` comes from the project-local `functions` module.
df = load_data("Data/Hormone_time_series_data.csv")

# Participant metadata; its `id_participant` column is matched against df["PID"] below.
df_meta = pd.read_csv("Data/Metadata_RedCap.csv")

# presumably resamples/interpolates every series onto a shared time grid, with
# `interp_limit` capping how much may be interpolated -- TODO confirm in functions.py
df_common = common_time(df, interp_limit = 40)

# Cortisol values as an array; later cells index it as num[i] and num[i, :],
# i.e. one row per fitted time series.
num = df_to_numpy(df_common, "Cortisol")

# Column labels from position 4 up to (but excluding) the last column.
# presumably these are the hormone columns -- verify against the CSV layout
hormones = df.columns[4:-1]

# First 72 values of the "NewTime" column.
# assumes these are the shared time points of one series -- TODO confirm
common_t = df_common["NewTime"].iloc[0:72].values

# NOTE (original author): one of the IDs is not in the time series data.
# Keep the metadata rows whose `id_participant` does NOT occur in df["PID"] ...
not_in_time_series = df_meta[~df_meta['id_participant'].isin(df['PID'])]
# ... and restrict the metadata to participants that do occur in df["PID"].
df_meta = df_meta[df_meta['id_participant'].isin(df['PID'])]
# Reset the index so the filtered rows are numbered consecutively from 0.
df_meta = df_meta.reset_index(drop=True)

2 Fit the basis functions

We can write our simple model for the data as a sum of basis functions:

\[ x_i(t) = \sum_{j=1}^k w_{i,j} \phi_j(t - d_j) \]

Where \(x_i(t)\) is the value of the \(i\) th time series at time \(t\), \(w_{i,j}\) is the weight of the \(j\) th basis function for person \(i\), \(\phi_j\) is the \(j\) th basis function, \(d_j\) is its shift (peak location), and \(k\) is the number of basis functions.

We take the basis functions to be von Mises distributions, whose density is given by:

\[ f(x | \mu, \kappa) = \frac{e^{\kappa \cos(x - \mu)}}{2\pi I_0(\kappa)} \]

These \(\mu\)’s are equivalent to the shift parameter we wrote down previously.

And to fit our von Mises basis we can do the following:

In the model fit we have the following constraints/bounds:

  • The peak location is between \(-\pi\) and \(\pi\).
  • The amplitude/weight is between 0 and 40. (always positive)
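  • The dispersion is between 0 and 200 (the first fitting routine below raises the lower bound to 2 to favour peaked basis functions; the re-fitting routine uses 0).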
Code
def func(x, *params):
    """Evaluate a weighted sum of von Mises densities at ``x``.

    ``params`` is a flat sequence of triples, one per basis function:
    ``(location, weight, dispersion)``. Each triple contributes
    ``weight * exp(dispersion * cos(x - location)) / (2 * pi * I0(dispersion))``.
    With no parameters the result is all zeros, shaped like ``x``.
    """
    total = np.zeros_like(x)
    for start in range(0, len(params), 3):
        location, weight, kappa = params[start], params[start + 1], params[start + 2]
        normaliser = 2 * np.pi * i0(kappa)
        total = total + weight * np.exp(kappa * np.cos(x - location)) / normaliser
    return total

# Angular grid shared by every fit and plot: 72 evenly spaced points on
# [-pi, pi], both endpoints included.
# NOTE(review): -pi and pi are the same angle for the periodic basis functions,
# so the first and last grid points coincide on the circle -- confirm this
# matches how the 72 samples are spaced in time.
t = np.linspace(-np.pi, np.pi, num=72)

def fit_von_mises(time_series, n_basis):
    """Fit a sum of ``n_basis`` von Mises basis functions to one time series.

    The model is ``func`` evaluated on the module-level grid ``t``, so
    ``time_series`` must have one value per point of ``t``.

    Parameters
    ----------
    time_series : array-like
        Observed values on the grid ``t``.
    n_basis : int
        Number of basis functions in the model.

    Returns
    -------
    numpy.ndarray
        Flat array of length ``3 * n_basis`` laid out as
        ``(location, weight, dispersion)`` per basis function, in the order
        expected by ``func``.

    Notes
    -----
    Every basis function starts from the same initial guess
    (location 0, weight 30, dispersion 6) and is constrained to
    location in [-pi, pi], weight in [0, 40] and dispersion in [2, 200].
    The dispersion lower bound of 2 keeps the fitted components peaked.
    """
    # One (location, weight, dispersion) triple per basis function.
    guess = np.tile([0.0, 30.0, 6.0], n_basis)
    lower = np.tile([-np.pi, 0.0, 2.0], n_basis)
    upper = np.tile([np.pi, 40.0, 200.0], n_basis)

    # Only the optimal parameters are needed; the covariance is discarded.
    popt, _ = curve_fit(func, t, time_series, p0=guess, maxfev=100000000, bounds=(lower, upper))

    return popt

# Row of `num` used as the worked example.
pat_ind = 31

# Fit three basis functions to that series.
popt = fit_von_mises(num[pat_ind], 3)

# Model prediction on the grid t.
fit = func(t, *popt)

plt.plot(t, num[pat_ind], 'bo-')
plt.plot(t, fit , 'r-')
plt.legend(['data', 'model fit'])

# Residual sum of squares of the fit.
# NOTE(review): the label below calls this "R squared", but the value is the
# sum of squared residuals (lower is better), not a coefficient of determination.
residuals = num[pat_ind] - func(t, *popt)
ss_res = np.sum(residuals**2)
# Adding text without box on the plot.
plt.text(0.01, 0.8,  f"R squared = {ss_res:.4}", ha='left', va='top', transform=plt.gca().transAxes)

# Overlay each fitted basis function separately; popt is laid out as
# (location, weight, dispersion) triples.
for i in range(0, len(popt), 3):
    ctr = popt[i]
    amp = popt[i+1]
    wid = popt[i+2]
    y = amp*np.exp(wid*np.cos(t-ctr))/(2*np.pi*i0(wid))
    plt.plot(t, y,'--', alpha = 0.8)

plt.title("Example Von Mises fit")
plt.show()

2.1 Fit 3 basis function to all patients

Let's fit 3 basis functions to every patient's time series, see which series the model fits well, and identify those that may need to be re-fitted because of problems with the default optimization.

When we fit the basis function model, the order of the basis functions is arbitrary: they can be interchanged without altering the overall model.

We will order them by the weight of the basis function.

Code
def order_basis(basis1, basis2, basis3):
    """Return the three basis-function triples ordered by decreasing weight.

    Each argument is an indexable ``(location, weight, dispersion)`` triple;
    only the weight at index 1 is compared.

    The sort is stable, so triples with equal weights keep their relative
    input order. Ties are handled explicitly because fitted weights can sit
    on a shared bound, in which case no single triple is strictly largest.

    Returns
    -------
    tuple
        ``(largest, middle, smallest)`` by weight; the inputs are not modified.
    """
    ordered = sorted((basis1, basis2, basis3), key=lambda basis: basis[1], reverse=True)
    return ordered[0], ordered[1], ordered[2]
Code
# Fit n_basis basis functions to every row of `num`, recording the residual
# sum of squares and the fitted parameters for each series.
n_basis = 3
ss_res_save = np.zeros(len(num))
basis_save = np.zeros([len(num), 3*n_basis])
# Placeholders only; all three are reassigned on every loop iteration.
basis1 = np.zeros(3)
basis2 = np.zeros(3)
basis3 = np.zeros(3)

for i in range(len(num)):
    popt = fit_von_mises(num[i], n_basis)
    # Split the flat parameter vector into (location, weight, dispersion) triples.
    basis1 = popt[0:3]
    basis2 = popt[3:6]
    basis3 = popt[6:9]
    fit = func(t, *popt)
    residuals = num[i,:] - fit
    # Sum of squared residuals: lower means a closer fit.
    ss_res = np.sum(residuals**2)
    ss_res_save[i] = ss_res

    # Order the basis functions by weight (index 1 of each triple)
    basis1, basis2, basis3 = order_basis(basis1, basis2, basis3)

    # Store the ordered triples back-to-back in one row per series.
    basis_save[i,0:3] = basis1
    basis_save[i,3:6] = basis2
    basis_save[i,6:9] = basis3

Plot a number of the model fits on a single plot

Code
def plot_single(popt, t, time_series, ax, ss_res=-1):
    """Draw one series, its fitted model and the individual basis functions on ``ax``.

    ``popt`` is a flat sequence of ``(location, weight, dispersion)`` triples.
    ``ss_res`` is printed in the top-left corner of the axes; the default of
    ``-1`` is shown as-is when the caller supplies no value.
    """
    # Observed data and the full model prediction.
    ax.plot(t, time_series, 'bo-')
    ax.plot(t, func(t, *popt), 'r-')

    # Fit-quality annotation in axes coordinates.
    ax.text(
        0.01,
        0.99,
        f"$R^2$ = {ss_res:.2f}",
        ha='left',
        va='top',
        transform=ax.transAxes,
        fontsize=12,
        color='black',
    )

    # One dashed curve per basis function.
    for start in range(0, len(popt), 3):
        location, weight, kappa = popt[start], popt[start + 1], popt[start + 2]
        component = weight * np.exp(kappa * np.cos(t - location)) / (2 * np.pi * i0(kappa))
        ax.plot(t, component, '--', alpha=0.8)

    # Restrict the y-axis ticks to integer values.
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))


# Grid of panels showing the first rows * cols fitted series.
rows = 8
cols = 5
fig, ax = plt.subplots(rows, cols, figsize=(21, 28))
plt.subplots_adjust(hspace=0.1, wspace=0.15)  # Adjust space between subplots


for i in range(rows):
    for j in range(cols):
        # Panel (i, j) shows series i*cols + j. The bounds check must use the
        # same index; panels beyond the available data stay empty.
        idx = i*cols + j
        if idx >= len(num):
            break
        plot_single(basis_save[idx, :], t, num[idx, :], ax[i, j], ss_res_save[idx])


# Remove x-tick labels for all plots except for the last row
for i in range(rows):
    for j in range(cols):
        if i < rows - 1:  # If not in the last row
            ax[i, j].set_xticklabels([])  # Remove x-tick labels
        else:
            # Label the angular grid from -pi to pi with clock hours.
            ax[i, j].set_xticks(np.linspace(-np.pi,np.pi,7), ['12', '16', '20', '00', '04', '08', '12'])  # Set x-tick labels

# Add a legend at the center top of the figure
fig.legend(['Data', 'Fit', 'Basis 1', 'Basis 2', 'Basis 3'], loc='upper center', ncol=5, fontsize=16, bbox_to_anchor=(0.5, 0.9))


plt.show()

2.1.1 Fix the fit for some time series

Sometimes the optimization does not find the best fit. So plot a histogram of the model fits to check the accuracy.

Code
# ss_res_save holds one sum of squared residuals per series (lower is better);
# the labels below refer to that quantity as "R squared".
plt.hist(ss_res_save)
plt.title("Histogram of R squared values")
plt.xlabel("R squared / error")
plt.ylabel("Frequency")

plt.show()


# Same values per series index, with a dashed reference line at 150.
plt.plot(ss_res_save, 'o')
plt.axhline(y = 150, color = 'r', linestyle = '--')
plt.ylabel("R squared")
plt.xlabel("Sample")

plt.show()

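As expected the majority have a low value (i.e., a good fit), but there are some that are high, suggesting that the optimization algorithm has not found the best fit / got stuck in a local minimum. Note that the quantity labelled "R squared" here is the residual sum of squares \(\sum_t (x(t) - \hat{x}(t))^2\), not the coefficient of determination, so lower values mean a better fit.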

Based on the histogram there are a handful that have R squared values of over 300 we will look at these individually.

Plot the model fits that have the worst R squared values, to see if these are just problems with the optimization algorithm.

Code
# Boolean mask of the series whose sum of squared residuals exceeds 300.
ss_res_save_over = ss_res_save > 300
print(f"Number of fits over threshold ($R^2 > 300$) is: {np.sum(ss_res_save_over)}")

# Data rows and fitted parameters for the flagged series only.
num_over = num[ss_res_save_over]
popt_over = basis_save[ss_res_save_over]

# One figure per flagged series: data, full fit (red) and each basis function (dashed).
for j in range(len(num_over)):
    plt.plot(t, num_over[j])
    plt.plot(t, func(t, *popt_over[j,:]) , 'r-')
    # Parameters are stored as (location, weight, dispersion) triples.
    for i in range(0, len(popt_over[0,:]), 3):
        ctr = popt_over[j,i]
        amp = popt_over[j,i+1]
        wid = popt_over[j,i+2]
        y = amp*np.exp(wid*np.cos(t-ctr))/(2*np.pi*i0(wid))
        plt.plot(t, y,'--', alpha = 0.8)

    plt.show()
Number of fits over threshold ($R^2 > 300$) is: 6

In the case that we have a poor fit, we will recalculate the fit, restarting the optimization procedure from different initial guesses to try to find a better fit.

Code
def R_squared(time_series, fit):
    """Return the sum of squared differences between ``time_series`` and ``fit``.

    Despite the name, this is the residual sum of squares, not a coefficient
    of determination: 0 means a perfect fit and larger values are worse.
    """
    return np.sum(np.square(time_series - fit))

def fit_von_mises_opt(time_series, n_basis, threshold, n_restarts=100, rng=None):
    """Fit von Mises basis functions, restarting from random guesses if the fit is poor.

    A first fit starts every basis function from location 0, weight 30 and
    dispersion 10. If its sum of squared residuals (as returned by
    ``R_squared``) exceeds ``threshold``, the fit is repeated from up to
    ``n_restarts`` random initial guesses drawn uniformly within the bounds.
    The restarts stop early once one of them falls below ``threshold``.

    Parameters
    ----------
    time_series : array-like
        Observed values on the module-level grid ``t``.
    n_basis : int
        Number of basis functions in the model.
    threshold : float
        Sum of squared residuals above which the first fit is considered poor.
    n_restarts : int, optional
        Maximum number of random restarts (default 100).
    rng : optional
        Object with a ``uniform(low, high)`` method, e.g. a
        ``numpy.random.Generator``, used to draw the random guesses. Pass a
        seeded generator for reproducible results. When ``None``, the global
        ``numpy.random`` state is used.

    Returns
    -------
    numpy.ndarray
        Flat array of ``(location, weight, dispersion)`` triples for the fit
        with the lowest sum of squared residuals found. This may still be
        above ``threshold`` if no restart does better.
    """
    # Bounds per basis function: location in [-pi, pi], weight in [0, 40],
    # dispersion in [0, 200].
    guess = np.tile([0.0, 30.0, 10.0], n_basis)
    lower = np.tile([-np.pi, 0.0, 0.0], n_basis)
    upper = np.tile([np.pi, 40.0, 200.0], n_basis)
    bounds = (lower, upper)

    popt_best, _ = curve_fit(func, t, time_series, p0=guess, maxfev=100000000, bounds=bounds)
    accuracy = R_squared(time_series, func(t, *popt_best))
    # Lowest sum of squared residuals seen so far.
    best_fit = accuracy

    # If the fit is bad, try the optimisation again from different initial guesses.
    if accuracy > threshold:
        sampler = np.random if rng is None else rng

        for _ in range(n_restarts):
            for k in range(n_basis):
                guess[3*k] = sampler.uniform(-np.pi, np.pi)  # location
                guess[3*k + 1] = sampler.uniform(0, 40)  # weight
                guess[3*k + 2] = sampler.uniform(0, 200)  # dispersion

            popt, _ = curve_fit(func, t, time_series, p0=guess, maxfev=1000000000, bounds=bounds)
            accuracy = R_squared(time_series, func(t, *popt))
            if best_fit > accuracy:
                best_fit = accuracy
                popt_best = popt

            # Good enough: stop restarting.
            if accuracy < threshold:
                break

    return popt_best

Use this new pipeline to fit all time series

Code
# Re-run the fit for every row of `num` with the restart-enabled routine.
# `threshold` is the sum of squared residuals above which restarts are tried.
threshold = 150
n_basis = 3
# These overwrite the results of the earlier fitting loop.
ss_res_save = np.zeros(len(num))
basis_save = np.zeros([len(num), 3*n_basis])
# Placeholders only; all three are reassigned on every loop iteration.
basis1 = np.zeros(3)
basis2 = np.zeros(3)
basis3 = np.zeros(3)

for i in range(len(num)):
    popt = fit_von_mises_opt(num[i], n_basis, threshold)
    fit = func(t, *popt)
    residuals = num[i,:] - fit
    # Sum of squared residuals: lower means a closer fit.
    ss_res = np.sum(residuals**2)
    ss_res_save[i] = ss_res

    # Order the (location, weight, dispersion) triples by weight.
    basis1, basis2, basis3 = order_basis(popt[0:3], popt[3:6], popt[6:9])

    basis_save[i,0:3] = basis1
    basis_save[i,3:6] = basis2
    basis_save[i,6:9] = basis3 

Plot a histogram of the model fits after re-fitting the time series that initially had a poor fit.

Code
# ss_res_save holds one sum of squared residuals per series (lower is better);
# the labels below refer to that quantity as $R^2$.
plt.hist(ss_res_save)
plt.title("Histogram of $R^2$ values")
plt.xlabel("$R^2$ / error")
plt.ylabel("Frequency")

plt.show()


# Same values per series index, with the re-fit threshold drawn as a dashed line.
plt.plot(ss_res_save, 'o')
plt.axhline(y = threshold, color = 'r', linestyle = '--')
plt.ylabel("$R^2$")
plt.xlabel("Sample")

plt.show()

Plot the model fits that have the worst R squared values, to see if these are just problems with the optimization algorithm.

Code
# Boolean mask of the series whose sum of squared residuals is still above the threshold.
ss_res_save_over = ss_res_save > threshold
print(f"Number of fits over threshold ($R^2 > {threshold}$) is: {np.sum(ss_res_save_over)}")

# Data rows and fitted parameters for the flagged series only.
num_over = num[ss_res_save_over]
popt_over = basis_save[ss_res_save_over]


# One figure per flagged series, annotated with its own residual value.
for j in range(len(num_over)):
    fig, ax = plt.subplots()
    plot_single(popt_over[j,:], t, num_over[j], ax, ss_res=ss_res_save[ss_res_save_over][j])

    plt.show()
Number of fits over threshold ($R^2 > 150$) is: 5

From these it seems that the model fits them adequately and the poor \(R^2\) is caused by them having multiple peaks that would require more basis functions and not a failure of the optimization.

2.1.2 Visualize distribution of fitted model parameters

The shift parameter is the 1st parameter, the weight/amplitude is the 2nd and the width/dispersion is the 3rd.

Code
# 3 x 3 grid of histograms over all series. Row j is the j-th basis function
# (as ordered by order_basis) and column i is its i-th parameter
# (0 = location, 1 = weight, 2 = dispersion), i.e. column i + 3*j of basis_save.
fig, ax = plt.subplots(3,3, figsize = (15,10))

for j in range(3):
    for i in range(3):
        ax[j,i].hist(basis_save[:,i + 3*j], bins = 20, density = True)
        ax[j,i].set_title(f"Distribution of parameter {i + 3*j}")




plt.show()

2.1.3 Save the parameters into CSV

Code
# Define the column names
column_names = ['PID', 'Cortisol', 'param_values', 'R2']

# Create a list to hold the new rows
rows = []

# assumes the order of unique PIDs matches the row order of `num` and
# `basis_save` -- TODO confirm against df_to_numpy
PIDs = df_common["PID"].unique()

for i in range(len(PIDs)):
    new_row = {
        'PID': PIDs[i], 
        # Whole arrays are stored in single cells of the table.
        'Cortisol': num[i],
        'param_values': basis_save[i, :],
        # Sum of squared residuals of the fit (lower is better), despite the column name.
        'R2': ss_res_save[i]
    }
    rows.append(new_row)

# Create the DataFrame from the list of rows
df_param = pd.DataFrame(rows, columns=column_names)

# Save the dataframe to csv
# NOTE(review): the array-valued cells are written as their text representation,
# so they will need parsing when the file is read back -- confirm this is intended.
df_param.to_csv("Param_values.csv", index=False)

# Display the first few rows of dataframe
df_param.head()
PID Cortisol param_values R2
0 429 [0.14200000000000002, 0.14850000000000002, 0.1... [-1.436815171293944, 1.7401832225748495, 3.254... 1.934149
1 431 [2.613, 2.614, 2.614, 3.198, 2.864, 2.639, 2.2... [2.4039101674693075, 15.669434479084323, 2.159... 40.694714
2 434 [2.3775, 2.4575, 2.4965, 2.5035, 2.54049999999... [-2.799606838469967, 8.45341152778352, 0.75886... 1.793045
3 273 [2.7695, 2.6995, 2.5965, 2.2875, 2.10200000000... [2.052116028811062, 10.900317449103497, 7.7894... 37.497884
4 437 [2.507, 3.181, 3.893, 3.941, 5.028, 5.341, 6.8... [2.4675896282608365, 18.975765306182897, 3.939... 130.141250